/*
 Copyright 2013--Present JMM_PROGNAME
 
 This file is distributed under the terms of the JMM_PROGNAME License.
 
 You should have received a copy of the JMM_PROGNAME License.
 If not, see <JMM_PROGNAME WEBSITE>.
*/
// CREATED    : 7/20/2015
// LAST UPDATE: 10/1/2015

#include "statsxx/optimization/gradient_based/RPROP.hpp"

// STL
#include <algorithm>  // std::max()
#include <cstdlib>    // std::abs()
#include <functional> // std::function<>
#include <tuple>      // std::tuple<>, std::make_tuple()

// jScience
#include "jmath_basic.h" // sign()


// Default-construct the optimizer with its standard tuning parameters.
inline RPROP::RPROP()
{
    // step-size adaptation: multiply by eta_p when consecutive gradient
    // components agree in sign, by eta_m when they disagree
    eta_p  = 1.1;
    eta_m  = 0.9;

    // initial step size, as a fraction of the starting point
    fdelta = 0.1;

    // (unused) explicit step-size bounds:
    // delta_min = 0.0;
    // delta_max = 0.1;

    // termination: relative tolerance on f and iteration budget
    ftol     = 1.0e-4;
    iter_max = 100;
}

// Destructor: no cleanup is performed here (empty body).
inline RPROP::~RPROP() {}


//========================================================================
//========================================================================
//
// NAME: std::tuple<int,
//                  Vector<double>,
//                  double> RPROP::minimize(std::function<double(const Vector<double> &)> f,
//                                          std::function<Vector<double>(const Vector<double> &)> df,
//                                          Vector<double> x)
//
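// DESC: Minimize the function f using RPROP: each component of x is stepped against the sign of its gradient by an individually adapted step size.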
//
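// INPUT:
//     std::function<double(const Vector<double> &)> f : objective function (only used to check convergence)
//     std::function<Vector<double>(const Vector<double> &)> df : gradient of f
//     Vector<double> x : starting point (its magnitude also scales the per-component step sizes and their upper limits)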
//
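// OUTPUT:
//     int info : 0 if the relative change in f fell below ftol; 1 if iter_max iterations were performed without converging
//     Vector<double> x : final point
//     double fx : value of f at the final point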
//
// NOTES:
//     ! This implementation is based on the standard RPROP used in machine learning.
//
//========================================================================
//========================================================================
inline std::tuple<
                  int,
                  Vector<double>,
                  double
                  > RPROP::minimize(
                                          std::function<double(const Vector<double> &)> f,
                                          std::function<Vector<double>(const Vector<double> &)> df,
                                          Vector<double> x
                                          )
{
    // RPROP: every component x(i) is moved by its own step size delta(i) in the
    // direction opposite to sign(g(i)); only the SIGN of the gradient is used.
    // delta(i) grows (eta_p) while the gradient sign is stable and shrinks (eta_m)
    // when it flips. f is evaluated only to test convergence.
    
    // info = 1 : iter_max iterations performed without meeting the convergence test
    // info = 0 : converged (relative change in f below ftol)
    int info = 1;
    
    // initial function value and gradient at the starting point
    double fx = f(x);
    Vector<double> g = df(x);
    
    // values from the previous iteration
    double fxm1 = fx;
    Vector<double> gm1 = g;
    
    // Per-component step sizes (delta) and their caps (delta_max), both scaled to
    // the magnitude of the initial x. Step sizes must be non-negative: the update
    // below moves x(i) against sign(g(i)), so a negative delta(i) would move it uphill.
    // NOTE(review): a component with x(i) == 0 gets delta(i) == delta_max(i) == 0
    // and is therefore never updated.
    Vector<double> delta(x.size());
    Vector<double> delta_max(x.size());
    
    for(auto i = 0; i < x.size(); ++i)
    {
        delta_max(i) = std::abs(x(i));
        delta(i)     = this->fdelta*delta_max(i);
    }
    
    // small absolute floor for the relative convergence test, so that it can
    // still succeed when f settles exactly at 0 (otherwise 0 < 0 never holds)
    const double eps = 1.0e-18;
    
    for(int iter = 0; iter < this->iter_max; ++iter)
    {
        // The gradient at the starting point is already known, and there is no
        // previous gradient to compare it with, so the first step uses the initial
        // step sizes unchanged. Later iterations evaluate the gradient at the new x
        // and adapt each step size from the sign of gm1(i)*g(i).
        if(iter > 0)
        {
            g = df(x);
            
            for(auto i = 0; i < g.size(); ++i)
            {
                const double gm1_g = gm1(i)*g(i);
                
                if(gm1_g > 0.0)
                {
                    // same sign: accelerate, but never beyond the cap
                    delta(i) = std::min( (delta(i)*this->eta_p), delta_max(i) );
                }
                else if(gm1_g < 0.0)
                {
                    // sign change: the previous step presumably jumped over a minimum
                    // along i, so shrink the step (stays >= 0 as long as eta_m > 0)
                    delta(i) *= this->eta_m;
                }
                // gm1_g == 0: leave delta(i) unchanged
            }
        }
        
        // step each component against the sign of its gradient
        // (behavior for g(i) == 0 depends on sign() from jmath_basic.h)
        for(auto i = 0; i < x.size(); ++i)
        {
            x(i) -= sign(g(i))*delta(i);
        }
        
        // convergence criterion: relative change in the function f()
        fx = f(x);
        
        if((2.0*std::abs(fx - fxm1)) < (this->ftol*(std::abs(fx) + std::abs(fxm1) + eps)))
        {
            info = 0;
            
            break;
        }
        
        fxm1 = fx;
        gm1 = g;
    }
    
    return std::make_tuple(info, x, fx);
}

